# !pip install miditok
# !pip install symusic
# !pip install glob
# !pip install torch
# !pip install pretty_midi
# !pip install midi2audio
import pretty_midi
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from miditok.pytorch_data import DatasetMIDI, DataCollator
import glob
from miditok import REMI, TokenizerConfig
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
import matplotlib.pyplot as plt
%matplotlib inline
import math
# Root of the NES Music Database MIDI dump (expects train/ and test/ subdirectories).
NESMDB_PATH = "./nesmdb_midi/"
# Sanity-check a single training file: list every instrument track with its
# note and control-change event counts (NES channels P1, P2, TR, NO per the output).
midi_data = pretty_midi.PrettyMIDI(NESMDB_PATH + 'train/297_SkyKid_00_01StartMusicBGMIntroBGM.mid')
for instrument in midi_data.instruments:
    print('-' * 80)
    print(instrument.name.upper())
    print('# note events: {}'.format(len(instrument.notes)))
    print('# control change events: {}'.format(len(instrument.control_changes)))
Path to dataset files: /home/josh/.cache/kagglehub/datasets/imsparsh/lakh-midi-clean/versions/1 -------------------------------------------------------------------------------- P1 # note events: 158 # control change events: 221 -------------------------------------------------------------------------------- P2 # note events: 197 # control change events: 73 -------------------------------------------------------------------------------- TR # note events: 123 # control change events: 0 -------------------------------------------------------------------------------- NO # note events: 6 # control change events: 164
# Collect all training/evaluation MIDI paths (glob order is filesystem-dependent).
train_files = glob.glob(NESMDB_PATH + "train/*.mid")
test_files = glob.glob(NESMDB_PATH + "test/*.mid")
# REMI tokenizer configuration: time-signature/tempo/program tokens plus full
# 127-level velocity resolution.
# NOTE(review): the ac_polyphony_* attribute controls are disabled by miditok at
# tokenizer construction (see the UserWarning emitted below) because they are
# incompatible with one_token_stream_for_programs — confirm they are still wanted.
config = TokenizerConfig(
    use_time_signatures=True,
    use_tempos=True,
    use_programs=True,
    num_velocities=127,
    ac_polyphony_track = True,
    ac_polyphony_bar = True,
)
tokenizer = REMI(config)
def _make_dataset(files_paths):
    """Build a DatasetMIDI over `files_paths` with the shared tokenizer settings.

    Factored out because the train and test datasets were constructed with two
    verbatim-identical argument lists; one helper keeps them in sync.
    """
    return DatasetMIDI(
        files_paths=files_paths,
        tokenizer=tokenizer,
        max_seq_len=1024,
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"],
    )

train_dataset = _make_dataset(train_files)
test_dataset = _make_dataset(test_files)
/home/josh/miniconda3/envs/gpu-env/lib/python3.11/site-packages/miditok/tokenizations/remi.py:88: UserWarning: Attribute controls are not compatible with 'config.one_token_stream_for_programs' and multi-vocabulary tokenizers. Disabling them from the config. super().__init__(tokenizer_config, params)
# NOTE(review): input_dir appears unused in the rest of this notebook — verify
# against later cells before removing.
input_dir = "./nesmdb_midi/train/"
# Pad variable-length token sequences within each batch to a common length.
collator = DataCollator(tokenizer.pad_token_id)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collator, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collator, num_workers=4)
# Cell output: number of train / test batches.
len(train_loader), len(test_loader)
(1126, 94)
class MusicGRU(nn.Module):
    """GRU language model over MIDI token ids.

    Embeds token ids, layer-normalizes the embeddings, runs them through a
    multi-layer GRU, and projects every timestep's hidden state to
    vocabulary-sized logits.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)
        # dropout is applied between stacked GRU layers (effective only when
        # num_layers > 1)
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Map (batch, seq_len) int token ids to ((batch, seq_len, vocab) logits, hidden)."""
        embedded = self.norm(self.embedding(x))
        states, hidden = self.gru(embedded, hidden)
        logits = self.fc(states)
        return logits, hidden
from torch.amp import GradScaler, autocast
def train(model, train_loader, val_loader, vocab_size, num_epochs=10, lr=0.001, device='cuda'):
model = model.to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), 3e-4, weight_decay=1e-2)
scaler = GradScaler('cuda')
for epoch in range(num_epochs):
# Training
model.train()
total_train_loss = 0
for batch in train_loader:
batch = batch['input_ids'].to(device) # (batch_size, seq_length)
inputs = batch[:, :-1]
targets = batch[:, 1:]
optimizer.zero_grad()
with autocast('cuda'):
outputs, _ = model(inputs)
outputs = outputs.reshape(-1, vocab_size)
targets = targets.reshape(-1)
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
total_train_loss += loss.item()
avg_train_loss = total_train_loss / len(train_loader)
# Validation
model.eval()
total_val_loss = 0
with torch.no_grad():
for batch in val_loader:
batch = batch['input_ids'].to(device)
inputs = batch[:, :-1]
targets = batch[:, 1:]
outputs, _ = model(inputs)
outputs = outputs.reshape(-1, vocab_size)
targets = targets.reshape(-1)
loss = criterion(outputs, targets)
total_val_loss += loss.item()
avg_val_loss = total_val_loss / len(val_loader)
print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
if __name__ == "__main__":
    # Hyperparameters for the NES-MDB GRU language model.
    # NOTE: vocab_size and model stay module-level names — later cells use them.
    vocab_size = tokenizer.vocab_size
    embedding_dim = 128
    hidden_dim = 512
    num_layers = 2
    model = MusicGRU(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
    )
    train(model, train_loader, test_loader, vocab_size)
Epoch 1/10 | Train Loss: 2.4395 | Val Loss: 1.9650 Epoch 2/10 | Train Loss: 2.0247 | Val Loss: 1.8584 Epoch 3/10 | Train Loss: 1.9369 | Val Loss: 1.7829 Epoch 4/10 | Train Loss: 1.8575 | Val Loss: 1.7195 Epoch 5/10 | Train Loss: 1.7932 | Val Loss: 1.6797 Epoch 6/10 | Train Loss: 1.7512 | Val Loss: 1.6502 Epoch 7/10 | Train Loss: 1.7142 | Val Loss: 1.6344 Epoch 8/10 | Train Loss: 1.6886 | Val Loss: 1.6180 Epoch 9/10 | Train Loss: 1.6708 | Val Loss: 1.6053 Epoch 10/10 | Train Loss: 1.6510 | Val Loss: 1.5989
def sample(model, start_token, max_length=100, temperature=0.8, device='cuda', eos_tokens=(0, 2)):
    """Autoregressively sample a token sequence from a (logits, hidden) model.

    Args:
        model: module whose forward(x, hidden) returns (logits, hidden), with
            logits shaped (1, 1, vocab_size) for a (1, 1) input.
        start_token: id the sequence is seeded with (always first in the result).
        max_length: maximum number of tokens generated after the start token.
        temperature: softmax temperature; lower values sharpen the distribution.
        device: device to sample on.
        eos_tokens: token ids that terminate generation. Defaults to (0, 2),
            matching the previously hard-coded PAD/EOS ids, so existing callers
            see identical behavior.

    Returns:
        list[int]: generated token ids, starting with `start_token` and ending
        with an eos token if one was drawn within `max_length` steps.
    """
    model = model.to(device)
    model.eval()
    generated = [start_token]
    input_token = torch.tensor([[start_token]], device=device)  # (1, 1)
    hidden = None
    # BUGFIX: inference previously ran without no_grad, building an ever-growing
    # autograd graph across the sampling loop.
    with torch.no_grad():
        for _ in range(max_length):
            output, hidden = model(input_token, hidden)  # (1, 1, vocab_size)
            logits = output[:, -1, :] / temperature      # last step, temperature-scaled
            probs = F.softmax(logits, dim=-1)            # (1, vocab_size)
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)
            if next_token in eos_tokens:  # reached end of sequence
                break
            input_token = torch.tensor([[next_token]], device=device)
    return generated
# special_tokens_ids[1] — presumably the BOS id (the datasets above use
# tokenizer["BOS_None"]); TODO confirm against tokenizer.special_tokens ordering.
start_token = tokenizer.special_tokens_ids[1]
generated_sequence = sample(model, start_token, max_length=2048)
print("Generated token sequence:")
print(generated_sequence)
import midi2audio
from midi2audio import FluidSynth
from IPython.display import Audio, display

# NOTE(review): this soundfont file must exist — the earlier run failed with
# "fopen() failed" because FluidR3Mono_GM.sf3 was missing.
fs = FluidSynth("FluidR3Mono_GM.sf3")
# `tokens_to_midi` is deprecated in miditok; `decode` is the renamed equivalent.
output_score = tokenizer.decode(generated_sequence)
print(type(output_score))
# Boost all note velocities into [60, 127] so quiet tracks remain audible.
for track in output_score.tracks:
    for note in track.notes:
        note.velocity = min(127, max(60, int(note.velocity * 2)))
output_score.dump_midi("rnn.mid")
fs.midi_to_audio("rnn.mid", "rnn.wav")
display(Audio("rnn.wav"))
# Overlaid piano-roll visualisation of every instrument in the generated MIDI.
output_pm = pretty_midi.PrettyMIDI("rnn.mid")
for i, inst in enumerate(output_pm.instruments):
    pr = inst.get_piano_roll(fs=100)
    plt.imshow(pr, aspect='auto', origin='lower', alpha=0.5, cmap='hot')
Generated token sequence: [1, 4, 610, 284, 453, 553, 36, 107, 223, 554, 34, 107, 223, 511, 36, 93, 221, 288, 553, 36, 107, 221, 554, 35, 107, 221, 511, 36, 93, 221, 290, 553, 37, 107, 221, 554, 36, 107, 221, 511, 36, 93, 221, 293, 553, 38, 107, 221, 554, 35, 107, 221, 511, 38, 93, 221, 295, 553, 40, 107, 221, 554, 36, 107, 221, 511, 36, 93, 221, 297, 553, 45, 107, 221, 554, 398, 107, 221, 511, 37, 93, 221, 299, 553, 43, 107, 221, 554, 56, 107, 221, 511, 36, 93, 221, 301, 553, 43, 107, 221, 554, 57, 107, 221, 511, 91, 93, 220, 511, 40, 93, 221, 305, 553, 38, 107, 572, 554, 59, 107, 221, 511, 38, 93, 221, 307, 554, 59, 107, 221, 511, 34, 93, 221, 308, 554, 55, 107, 221, 511, 591, 93, 221, 310, 554, 57, 107, 221, 511, 34, 93, 221, 312, 554, 57, 107, 221, 511, 31, 93, 221, 314, 554, 62, 107, 221, 511, 27, 93, 221, 4, 610, 284, 554, 51, 107, 221, 511, 24, 93, 221, 286, 554, 47, 107, 221, 511, 27, 93, 221, 289, 554, 50, 107, 221, 511, 34, 93, 221, 290, 554, 41, 107, 221, 511, 117, 93, 221, 291, 554, 39, 107, 221, 511, 24, 93, 221, 293, 554, 39, 107, 221, 295, 554, 36, 107, 221, 511, 20, 93, 221, 297, 554, 36, 107, 221, 511, 20, 531, 221, 300, 554, 38, 107, 221, 511, 18, 93, 220, 302, 554, 38, 107, 221, 511, 18, 93, 220, 303, 554, 36, 107, 221, 511, 32, 93, 221, 305, 554, 36, 107, 221, 523, 511, 31, 93, 221, 308, 554, 42, 107, 221, 511, 20, 93, 221, 311, 554, 40, 107, 221, 511, 33, 93, 221, 314, 554, 39, 107, 221, 511, 27, 93, 221, 4, 610, 284, 554, 39, 107, 221, 511, 27, 93, 221, 286, 554, 39, 107, 221, 511, 28, 93, 220, 288, 554, 39, 107, 221, 511, 28, 93, 221, 290, 554, 39, 107, 221, 511, 27, 93, 220, 293, 554, 40, 107, 221, 511, 27, 93, 220, 297, 554, 40, 107, 221, 511, 29, 93, 221, 300, 554, 39, 107, 221, 511, 27, 93, 221, 303, 554, 39, 107, 221, 511, 28, 93, 220, 306, 554, 38, 107, 221, 511, 28, 93, 221, 308, 554, 35, 107, 221, 511, 24, 93, 220, 310, 554, 35, 107, 221, 511, 27, 93, 221, 312, 554, 32, 107, 221, 511, 28, 93, 220, 314, 554, 39, 107, 221, 511, 31, 93, 
221, 4, 610, 284, 554, 40, 107, 221, 511, 27, 93, 220, 286, 554, 38, 107, 221, 511, 24, 93, 221, 288, 554, 38, 107, 221, 511, 27, 93, 221, 290, 554, 39, 107, 221, 511, 27, 93, 225, 293, 554, 42, 107, 221, 295, 554, 41, 107, 221, 297, 554, 35, 107, 221, 511, 24, 93, 221, 299, 554, 39, 107, 221, 511, 27, 93, 220, 301, 554, 39, 107, 221, 511, 27, 93, 220, 303, 554, 39, 107, 221, 511, 36, 93, 479, 307, 554, 39, 107, 221, 310, 554, 37, 107, 221, 312, 554, 29, 107, 221, 314, 554, 39, 107, 221, 4, 610, 284, 554, 39, 107, 221, 511, 27, 93, 220, 286, 554, 39, 107, 221, 511, 32, 93, 220, 445, 554, 39, 107, 221, 511, 27, 93, 220, 290, 554, 38, 107, 223, 511, 27, 93, 220, 439, 511, 27, 93, 220, 297, 553, 51, 107, 220, 553, 39, 107, 221, 491, 554, 38, 107, 221, 511, 27, 93, 220, 298, 553, 40, 107, 221, 451, 554, 39, 107, 221, 511, 27, 93, 220, 301, 553, 39, 107, 221, 511, 36, 93, 220, 303, 554, 41, 107, 221, 511, 29, 93, 220, 306, 553, 48, 107, 221, 554, 44, 107, 221, 511, 27, 93, 220, 308, 553, 51, 107, 221, 511, 27, 93, 220, 310, 554, 40, 107, 221, 182, 553, 51, 107, 221, 511, 27, 93, 220, 312, 554, 39, 107, 221, 314, 553, 49, 107, 221, 554, 44, 107, 221, 511, 32, 93, 220, 4, 610, 284, 553, 48, 107, 221, 554, 39, 107, 221, 511, 27, 93, 220, 285, 553, 51, 107, 221, 554, 39, 107, 221, 511, 27, 93, 220, 287, 553, 49, 107, 221, 554, 39, 107, 221, 511, 27, 93, 220, 288, 553, 51, 107, 221, 554, 39, 107, 221, 511, 27, 93, 220, 290, 36, 553, 51, 107, 221, 554, 39, 107, 221, 511, 27, 93, 220, 294, 553, 56, 107, 221, 554, 39, 107, 221, 511, 27, 93, 220, 297, 553, 56, 107, 221, 554, 36, 96, 221, 511, 27, 93, 220, 299, 553, 60, 107, 221, 554, 39, 107, 221, 511, 27, 93, 220, 301, 553, 58, 107, 221, 554, 43, 107, 221, 511, 27, 93, 220, 303, 554, 39, 107, 221, 511, 24, 93, 220, 305, 554, 39, 107, 221, 511, 27, 93, 220, 2] <class 'symusic.core.ScoreTick'> FluidSynth runtime version 2.3.4 Copyright (C) 2000-2023 Peter Hanappe and others. Distributed under the LGPL license. 
SoundFont(R) is a registered trademark of Creative Technology Ltd. Rendering audio to file 'rnn.wav'..
/tmp/ipykernel_3953/145865254.py:12: UserWarning: miditok: The `tokens_to_midi` method had been renamed `decode`. It is now depreciated and will be removed in future updates. output_score = tokenizer.tokens_to_midi(generated_sequence) fluidsynth: error: fluid_is_soundfont(): fopen() failed: 'File does not exist.' Parameter 'FluidR3Mono_GM.sf3' not a SoundFont or MIDI file or error occurred identifying it.
# Final-epoch metrics transcribed from the training run above;
# perplexity is exp(cross-entropy loss).
train_loss = 1.6510
val_loss = 1.5989
train_perplexity, val_perplexity = (math.exp(v) for v in (train_loss, val_loss))
print(f"Model Train Loss: {train_loss:.2f}")
print(f"Model Train Perplexity: {train_perplexity:.2f}")
print(f"Model Validation Loss: {val_loss:.2f}")
print(f"Model Validation Perplexity: {val_perplexity:.2f}")
Model Train Loss: 1.65 Model Train Perplexity: 5.21 Model Validation Loss: 1.60 Model Validation Perplexity: 4.95
def generate_random_tokens(vocab_size, max_length=512):
    """Return `max_length` uniform random token ids drawn from [0, vocab_size)."""
    draws = torch.randint(low=0, high=vocab_size, size=(max_length,), dtype=torch.long)
    return draws.tolist()
def generate_random_sequence(tokenizer, max_length=512):
    """Build a BOS-prefixed sequence of `max_length` random tokens.

    PAD and EOS ids are rejected (re-drawn) so the sequence never terminates
    early when decoded; the first element is always the BOS id.
    """
    vocab = tokenizer.vocab_size
    banned = {tokenizer["PAD_None"], tokenizer["EOS_None"]}
    body = []
    # Rejection-sample until we have max_length - 1 usable tokens.
    while len(body) < max_length - 1:
        candidate = torch.randint(0, vocab, (1,)).item()
        if candidate not in banned:
            body.append(candidate)
    return [tokenizer["BOS_None"]] + body
# Random baseline: decode uniformly random tokens to audio for comparison
# against the model's samples.
random_tokens = generate_random_sequence(tokenizer, max_length=512)
# `tokens_to_midi` is deprecated in miditok; `decode` is the renamed equivalent.
rand_score = tokenizer.decode(random_tokens)
# Boost all note velocities into [60, 127], same post-processing as the model output.
for track in rand_score.tracks:
    for note in track.notes:
        note.velocity = min(127, max(60, int(note.velocity * 2)))
rand_score.dump_midi("rando.mid")
fs.midi_to_audio("rando.mid", "rando.wav")
display(Audio("rando.wav"))
# Overlaid piano-roll visualisation of every instrument in the random MIDI.
rand_pm = pretty_midi.PrettyMIDI("rando.mid")
for i, inst in enumerate(rand_pm.instruments):
    pr = inst.get_piano_roll(fs=100)
    plt.imshow(pr, aspect='auto', origin='lower', alpha=0.5, cmap='hot')
FluidSynth runtime version 2.3.4 Copyright (C) 2000-2023 Peter Hanappe and others. Distributed under the LGPL license. SoundFont(R) is a registered trademark of Creative Technology Ltd. Rendering audio to file 'rando.wav'..
/tmp/ipykernel_3953/197685182.py:2: UserWarning: miditok: The `tokens_to_midi` method had been renamed `decode`. It is now depreciated and will be removed in future updates. rand_score = tokenizer.tokens_to_midi(random_tokens) fluidsynth: error: fluid_is_soundfont(): fopen() failed: 'File does not exist.' Parameter 'FluidR3Mono_GM.sf3' not a SoundFont or MIDI file or error occurred identifying it.
# Uniform-random baseline: the cross-entropy of guessing uniformly over every
# token the random generator can emit (PAD and EOS excluded).
# PERF/BUGFIX: the exclusion list (with its two tokenizer lookups) was rebuilt
# on every comprehension iteration; hoist it into a set built once.
excluded = {tokenizer["PAD_None"], tokenizer["EOS_None"]}
baseline_vocab = [i for i in range(vocab_size) if i not in excluded]
V = len(baseline_vocab)
uniform_crossentropyloss = math.log(V)   # cross-entropy of the uniform distribution
uniform_perplexity = math.exp(uniform_crossentropyloss)  # == V
print(f"Baseline Loss: {uniform_crossentropyloss:.2f}")
print(f"Baseline Perplexity: {uniform_perplexity:.2f}")
Baseline Loss: 6.41 Baseline Perplexity: 609.00